// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest
import OpenSpiel

/// Exercises best-response and NashConv computations against values derived by
/// hand for 2-player Kuhn and Leduc poker, where the policy being exploited is
/// `UniformRandomPolicy`.
final class ExploitabilityTests: XCTestCase {
  /// Player 0's best response to a uniform-random opponent in 2-player Kuhn.
  ///
  /// Keys are player 0's private card followed by the public action history
  /// ("p" = pass, "b" = bet); passing after a bet is a fold, betting is a call.
  func test2PlayerKuhnPokerPlayer0BestResponse() throws {
    let game = KuhnPoker(playerCount: 2)
    let opponentPolicy = UniformRandomPolicy(game)
    let bestResponse = BestResponse(game: game, player: .player(0), policy: opponentPolicy)

    XCTAssertEqual(bestResponse.bestResponseActions, [
      // Opening decision.
      "0": .bet,  // Bluff: showdown always loses, but the opponent may fold.
      "1": .bet,  // The opponent may fold a hand that would beat ours.
      "2": .pass, // Tie between pass and bet; the lowest action is returned.
      // Facing a bet after we passed.
      "0pb": .pass, // Fold: this hand cannot win.
      "1pb": .bet,  // Call: an even coin flip.
      "2pb": .bet,  // Call: this hand cannot lose.
    ])
  }

  /// Player 1's best response to a uniform-random opponent in 2-player Kuhn.
  func test2PlayerKuhnPokerPlayer1BestResponse() throws {
    let game = KuhnPoker(playerCount: 2)
    let opponentPolicy = UniformRandomPolicy(game)
    let bestResponse = BestResponse(game: game, player: .player(1), policy: opponentPolicy)

    XCTAssertEqual(bestResponse.bestResponseActions, [
      // After the opponent passes, betting wins out with every card.
      "0p": .bet,
      "1p": .bet,
      "2p": .bet,
      // After the opponent bets, call with anything but the lowest card.
      "0b": .pass,
      "1b": .bet,
      "2b": .bet,
    ])
  }

  /// NashConv of the uniform-random policy in 2-player Kuhn poker.
  func test2PlayerKuhnPokerNashConv() throws {
    let game = KuhnPoker(playerCount: 2)
    let nashConv = game.nashConv(UniformRandomPolicy(game))
    XCTAssertEqual(nashConv, 11.0 / 12.0, accuracy: 1e-5)
  }

  /// NashConv of the uniform-random policy in 2-player Leduc poker.
  func test2PlayerLeducPokerNashConv() throws {
    let game = LeducPoker(playerCount: 2)
    let nashConv = game.nashConv(UniformRandomPolicy(game))
    XCTAssertEqual(nashConv, 4.747222222222222, accuracy: 1e-5)
  }
}

// Explicit (name, method) manifest in the shape XCTest's `testCase(_:)` takes.
// Presumably consumed where tests cannot be discovered at runtime (e.g. a
// Linux `XCTMain`); NOTE(review): confirm against the package's LinuxMain.
extension ExploitabilityTests {
  static var allTests: [(String, (ExploitabilityTests) -> () throws -> Void)] = [
    ("test2PlayerKuhnPokerPlayer0BestResponse", test2PlayerKuhnPokerPlayer0BestResponse),
    ("test2PlayerKuhnPokerPlayer1BestResponse", test2PlayerKuhnPokerPlayer1BestResponse),
    ("test2PlayerKuhnPokerNashConv", test2PlayerKuhnPokerNashConv),
    ("test2PlayerLeducPokerNashConv", test2PlayerLeducPokerNashConv)
  ]
}
